from datasets import load_dataset, Dataset

# Prompt template: instructs the model to put its reasoning and its final
# answer in separate <reasoning> / <answer> sections. The text begins and
# ends with a newline.
SYSTEM_PROMPT = (
    "\n"
    "Respond in the following format:\n"
    "<reasoning>\n"
    "...\n"
    "</reasoning>\n"
    "<answer>\n"
    "...\n"
    "</answer>\n"
)


# 根据数据集文本格式提取答案
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()


def get_gsm8k_dataset(dataset_path: str, split: str = "train") -> Dataset:
    """Load one split of a dataset and convert its records to chat prompts.

    Args:
        dataset_path: Passed unchanged to ``load_dataset``.
        split: Key of the split to select from the loaded dataset.

    Returns:
        The selected split after ``map``: every record gets a ``"prompt"``
        (a system message holding ``SYSTEM_PROMPT`` followed by a user
        message holding the record's ``"question"``) and an ``"answer"``
        produced by ``extract_hash_answer`` from the record's ``"answer"``
        (``None`` when that text has no ``####`` marker).
    """

    def _to_chat_record(record):
        # Build the two-message conversation for this record.
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": record["question"]},
        ]
        return {
            "prompt": messages,
            "answer": extract_hash_answer(record["answer"]),
        }

    all_splits = load_dataset(dataset_path)
    selected = all_splits[split]
    return selected.map(_to_chat_record)


if __name__ == "__main__":
    # Smoke test: load a dataset and print how many records it has.
    # get_gsm8k_dataset() requires a dataset path, so it is taken from the
    # command line rather than being called with no arguments (which raised
    # TypeError).
    import argparse

    parser = argparse.ArgumentParser(
        description="Load a GSM8K-style dataset and print its record count."
    )
    parser.add_argument(
        "dataset_path",
        help="dataset path or identifier forwarded to load_dataset",
    )
    parser.add_argument(
        "--split",
        default="train",
        help="split to load (default: %(default)s)",
    )
    args = parser.parse_args()

    dataset = get_gsm8k_dataset(args.dataset_path, split=args.split)
    print(len(dataset))
